AI Chess Master - Computer Vision Final¶

Author: Ibrahim Sobh

Instructions:

This notebook includes:

  • Data Exploration ✅
  • Models Training ✅
  • Performance Evaluation ✅
  • Saving Model for Production ✅

Dataset https://www.kaggle.com/code/jayshworkhadka/chess-fen-prediction </br>

Importing Libraries¶

In [ ]:
import re
import cv2
import glob
import random as rd
import warnings
import numpy as np
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")      
from sklearn.metrics import confusion_matrix,f1_score
from sklearn.metrics import accuracy_score, precision_score, recall_score
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
%matplotlib inline

1- Exploratory Data Analysis for data quality review¶

In [ ]:
#Define train path and test path
train_path = glob.glob("../data/train/*.jpeg")
test_path= glob.glob("../data/test/*.jpeg")

#Shuffle train and test samples
rd.shuffle(train_path)
rd.shuffle(test_path)

#Decide the number of train and test samples
train_size =20000
test_size = 4000
train = train_path[:train_size]
test= test_path[:test_size]

# Piece_type = ['King','Queen','Rook','Bishop','Knight','Pawn']|
# Capital = White, Normal = Black
piece_symbols = 'prbnkqPRBNKQ'
  • Define a function to extract labels/FEN from Images
In [ ]:
# Get the labels ( FNE ) for the training and testing images 
def get_image_FEN_label(image_path):
    fen_label= image_path.replace('.jpeg', '').split('/')[-1]
    return fen_label
  • Display a random sample of the data
In [ ]:
rand = np.random.randint(0, train_size)
img_path =train[rand]
img_moves =  get_image_FEN_label(img_path)
img_rand=cv2.imread(img_path)
plt.imshow(cv2.cvtColor(img_rand, cv2.COLOR_BGR2RGB))
plt.title(img_moves)
plt.axis('off')
plt.tight_layout()
plt.show()
  • Check the Python Chess Library to understand the FEN format
In [ ]:
import chess.svg
import chess
print("The FEN notation of the image is: ", img_moves)
board = chess.Board(img_moves.replace('-', '/'))
chess.svg.board(board, size=300)
The FEN notation of the image is:  7B-8-1RK2B2-5NR1-4q2p-1P6-2k2P2-2B2N2
Out[ ]:
. . . . . . . B
. . . . . . . .
. R K . . B . .
. . . . . N R .
. . . . q . . p
. P . . . . . .
. . k . . P . .
. . B . . N . .
  • Display a bunch of samples of the data
In [ ]:
samples =rd.sample(train, 9)
fig = plt.figure(figsize=(11, 11))
columns = 3
rows = 3
for i, img in zip(range(1, columns*rows +1),samples ):
    fig.add_subplot(rows, columns, i)
    img_moves =  get_image_FEN_label(img)
    img = cv2.imread(img)
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.title(img_moves)
    plt.tight_layout()
    
plt.show()
  • Labels distribution - All labels are Unique
In [ ]:
labels = [get_image_FEN_label(img) for img in train]
set_labels = set(labels)
print("Number of labels {} is equal to number of images {}".format(len(set_labels), len(train)))
Number of labels 20000 is equal to number of images 20000
  • Analyzing pictures dimensions and ratio - Same Ratio, Width, Height for all pictures*
In [ ]:
ratios = []
widths = []
heights = []

for img in train:
  img= cv2.imread(img)
  ratios.append(img.shape[1] / img.shape[0])  
  heights.append(img.shape[0])
  widths.append(img.shape[1])

fig, (ax1,ax2,ax3) = plt.subplots(1, 3, figsize=(12, 5))

ax1.hist(ratios, bins=50)
ax1.set_xlabel('ratio')
ax1.set_ylabel('count')
ax1.set_title('Average ratio: {}'.format(np.mean(ratios)))

ax2.hist(widths, bins=50)
ax2.set_xlabel('width')
ax2.set_ylabel('count')
ax2.set_title('Average width: {}'.format(np.mean(widths)))

ax3.hist(heights, bins=50)
ax3.set_xlabel('height')
ax3.set_ylabel('count')
ax3.set_title('Average height: {}'.format(np.mean(heights)))

print("Selected Width X heights: {}X{}".format(int(np.mean(widths)) ,int(np.mean(heights))))
Selected Width X heights: 400X400

2 - Data Preprocessing - Scaling, Normalization, etc.¶

2.B - Resizing, Scaling, Normalization, etc.¶

  • Create a function to Greyscale,Resize and Normalize the data
In [ ]:
def preprocess_some_images(img_paths, width, height):
  resized_imgs = []
  for img_path in img_paths:
    # change to Grey scal
    # img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    img = cv2.imread(img_path, cv2.COLOR_BGR2GRAY)
  
    # resize the image to the desired size
    gray_image = cv2.resize(img, (width, height))
    
    # Normalize the image
    gray_image =(gray_image - np.min(gray_image)) / (np.max(gray_image) - np.min(gray_image))

    # add to the list
    resized_imgs.append(gray_image)
    
  return resized_imgs
In [ ]:
preprocessed_imgs= preprocess_some_images(samples,240,240)
  • Display new image HxW dimensions
In [ ]:
"{}X{}".format(preprocessed_imgs[0].shape[0],preprocessed_imgs[0].shape[1])
Out[ ]:
'240X240'
  • Display new image shape
In [ ]:
preprocessed_imgs[0].shape
Out[ ]:
(240, 240, 3)
  • Data Diplay - Greyscale, Resized and Normalized
In [ ]:
fig = plt.figure(figsize=(11, 11))
columns = 3
rows = 3
for i, img in zip(range(1, columns*rows +1), preprocessed_imgs):
    fig.add_subplot(rows, columns, i)
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.tight_layout()
plt.show()

2.B - PCA Analysis¶

  1. On the Complete Chess Board
  • Prepare Image Data for PCA Analysis
In [ ]:
img = np.array(cv2.imread(samples[8]))
new_img=img.reshape(img.shape[0], (img.shape[1]*img.shape[2]))
new_img= cv2.resize(new_img, (720, 720),interpolation=cv2.INTER_CUBIC)
plt.figure(figsize=(6, 6))
plt.imshow(new_img, cmap='gray')
new_img.shape
Out[ ]:
(720, 720)
  • PCA Analysis - Components Analysis
In [ ]:
#Import required modules
from sklearn.decomposition import PCA

pca = PCA()
pca.n_components = 15
img_transformed = pca.fit_transform(new_img)

percentage_var_explained = pca.explained_variance_ / np.sum(pca.explained_variance_);

cum_var_explained = np.cumsum(percentage_var_explained)

reserved =np.round(np.sum(pca.explained_variance_ratio_),3)*100

print("Using {} components reservers {}% of the features".format(pca.n_components,reserved))

# Plot the PCA spectrum
plt.figure(1, figsize=(6, 4))

plt.clf()
plt.plot(cum_var_explained, linewidth=2)
plt.axis('tight')
plt.grid()
plt.xlabel('n_components')
plt.ylim(min(cum_var_explained),1)
plt.axhline(y=reserved/100.0, linestyle='--', color='k', linewidth=2)
plt.ylabel('Cumulative_explained_variance')
Using 15 components reservers 90.5% of the features
Out[ ]:
Text(0, 0.5, 'Cumulative_explained_variance')
  • PCA Analysis - Results
In [ ]:
fig, (ax1,ax2) = plt.subplots(1, 2, figsize=(12, 5))
img = np.array(cv2.imread(samples[8]))##np.array(preprocessed_imgs[0],copy=True)
new_img=img.reshape(img.shape[0], (img.shape[1]*img.shape[2]))
new_img= cv2.resize(new_img, (720, 720),interpolation=cv2.INTER_CUBIC)
ax1.imshow(new_img, cmap='gray')
ax1.set_title("Before PCA image")

plt.figure(figsize=(6, 6))
temp = pca.inverse_transform(img_transformed) 
temp = np.reshape(temp, (720,720)) 
ax2.imshow(temp, cmap='gray')
ax2.set_title("After PCA image")
Out[ ]:
Text(0.5, 1.0, 'After PCA image')
<Figure size 432x432 with 0 Axes>

2 . On the Chess Board Pieces

  • Prepare Image Data for PCA Analysis
In [ ]:
def image_to_squares_pca(img,heights,widths):
  squares = []
  for i in range(0,8):
    for j in range(0,8):
      new_img =img[i*heights//8:i*heights//8+heights//8,j*widths//8:j*widths//8+widths//8]
      new_img=new_img.reshape(new_img.shape[0], (new_img.shape[1]*new_img.shape[2]))
      new_img= cv2.resize(new_img, (720, 720),interpolation=cv2.INTER_CUBIC)
      squares.append(new_img)
  return np.array(squares)

img = np.array(cv2.imread(samples[8]))
sqaures= image_to_squares_pca(img,400,400)
fig = plt.figure(figsize=(10, 10))
columns = 8
rows = 8
for i, img in zip(range(1, columns*rows +1),sqaures):
    fig.add_subplot(rows, columns, i)
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.tight_layout()
plt.show()
sqaures.shape
Out[ ]:
(64, 720, 720)
  • PCA Analysis - Components Analysis
In [ ]:
#Import required modules
from sklearn.decomposition import PCA
test_sample= sqaures[3]
pca = PCA()
pca.n_components = 5
img_transformed = pca.fit_transform(test_sample)

percentage_var_explained = pca.explained_variance_ / np.sum(pca.explained_variance_);

cum_var_explained = np.cumsum(percentage_var_explained)

reserved =np.round(np.sum(pca.explained_variance_ratio_),3)*100

print("Using {} components reservers {}% of the features".format(pca.n_components,reserved))

# Plot the PCA spectrum
plt.figure(1, figsize=(6, 4))

plt.clf()
plt.plot(cum_var_explained, linewidth=2)
plt.axis('tight')
plt.grid()
plt.xlabel('n_components')
plt.ylim(min(cum_var_explained),1)
plt.axhline(y=reserved/100.0, linestyle='--', color='k', linewidth=2)
plt.ylabel('Cumulative_explained_variance')
Using 5 components reservers 94.69999999999999% of the features
Out[ ]:
Text(0, 0.5, 'Cumulative_explained_variance')
  • PCA Analysis - Results
In [ ]:
#Import required modules
from sklearn.decomposition import PCA
sqaures_PCA=[]
for img in sqaures:
    pca = PCA()
    pca.n_components = 5
    img_transformed = pca.fit_transform(img)
    temp = pca.inverse_transform(img_transformed) 
    temp = np.reshape(temp, (720,720)) 
    sqaures_PCA.append(temp)
    
fig = plt.figure(figsize=(10, 10))
columns = 8
rows = 8
for i, img in zip(range(1, columns*rows +1),sqaures_PCA):
    fig.add_subplot(rows, columns, i)
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.tight_layout()
plt.show()

3 - Feature Engineering¶

In [ ]:
def image_to_squares(img,heights,widths):
  squares = []
  for i in range(0,8):
    for j in range(0,8):
      squares.append(img[i*heights//8:i*heights//8+heights//8,j*widths//8:j*widths//8+widths//8])
  return np.array(squares)
In [ ]:
sqaures= image_to_squares(preprocessed_imgs[0],240,240)
fig = plt.figure(figsize=(10, 10))
columns = 8
rows = 8
for i, img in zip(range(1, columns*rows +1),sqaures):
    fig.add_subplot(rows, columns, i)
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.tight_layout()
plt.show()
sqaures.shape
Out[ ]:
(64, 30, 30, 3)
  • Create a function to the complete preprocessing and sqaures division of one sample of the data
In [ ]:
def preprocess_image(img_path):
    height =240
    width =240

    # change to Grey scal
    img = cv2.imread(img_path, cv2.COLOR_BGR2GRAY)
  
    # resize the image to the desired size
    gray_image = cv2.resize(img, (width, height))
    
    # Normalize the image
    gray_image =(gray_image - np.min(gray_image)) / (np.max(gray_image) - np.min(gray_image))

    squares = image_to_squares(gray_image,height,width)
    return squares
In [ ]:
sqaures=preprocess_image(train[444])

fig = plt.figure(figsize=(10, 10))
columns = 8
rows = 8
for i, img in zip(range(1, columns*rows +1),sqaures):
    fig.add_subplot(rows, columns, i)
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.tight_layout()
plt.show()
sqaures.shape
Out[ ]:
(64, 30, 30, 3)
  • Create a FEN Label Encode /Decoder Functions
  • NOTE These 2 Functions are taken from this Kaggle notebook : https://www.kaggle.com/code/koryakinp/chess-fen-generator/notebook
In [ ]:
def onehot_from_fen(fen):
    eye = np.eye(13)
    output = np.empty((0, 13))
    fen = re.sub('[-]', '', fen)

    for char in fen:
        if(char in '12345678'):
            output = np.append(
              output, np.tile(eye[12], (int(char), 1)), axis=0)
        else:
            idx = piece_symbols.index(char)
            output = np.append(output, eye[idx].reshape((1, 13)), axis=0)

    return output

def fen_from_onehot(one_hot):
    output = ''
    for j in range(8):
        for i in range(8):
            if(one_hot[j][i] == 12):
                output += ' '
            else:
                output += piece_symbols[one_hot[j][i]]
        if(j != 7):
            output += '-'

    for i in range(8, 0, -1):
        output = output.replace(' ' * i, str(i))

    return output
  • Create a function divide the data into training and testing sets
In [ ]:
def train_gen(features):
    for i, img in enumerate(features):
        y = onehot_from_fen(get_image_FEN_label(img))
        x = preprocess_image(img)
        yield x, y

def pred_gen(features):
    for i, img in enumerate(features):
        y = onehot_from_fen(get_image_FEN_label(img))
        x = preprocess_image(img)
        yield x, y

4 - Modeling & Model Training¶

In [ ]:
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D

# Add a convolutional layer
model = Sequential()
model.add(Convolution2D(32, (3, 3),activation='relu', input_shape=(30, 30, 3)))
model.add(MaxPooling2D(pool_size=(3, 3)))
model.add(Convolution2D(16, (5, 5),activation='relu'))
model.add(Flatten())
model.add(Dropout(0.35))
model.add(Dense(13, activation='softmax'))
model.summary()
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d_2 (Conv2D)           (None, 28, 28, 32)        896       
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 9, 9, 32)         0         
 2D)                                                             
                                                                 
 conv2d_3 (Conv2D)           (None, 5, 5, 16)          12816     
                                                                 
 flatten_1 (Flatten)         (None, 400)               0         
                                                                 
 dropout_1 (Dropout)         (None, 400)               0         
                                                                 
 dense_1 (Dense)             (None, 13)                5213      
                                                                 
=================================================================
Total params: 18,925
Trainable params: 18,925
Non-trainable params: 0
_________________________________________________________________
In [ ]:
# Compile the model 
model.compile(optimizer='adam', 
              loss='categorical_crossentropy', 
              metrics=['accuracy'])
In [ ]:
# Fit parameters
EPOCHS=100

print("\nTraining Progress:\n------------------------")
hist = model.fit_generator(train_gen(train), steps_per_epoch=train_size//EPOCHS, epochs=EPOCHS,
                           validation_data=pred_gen(test), validation_steps=test_size//EPOCHS)
Training Progress:
------------------------
Epoch 1/100
200/200 [==============================] - 2s 8ms/step - loss: 0.7468 - accuracy: 0.8266 - val_loss: 0.3671 - val_accuracy: 0.9023
Epoch 2/100
200/200 [==============================] - 2s 8ms/step - loss: 0.3300 - accuracy: 0.9091 - val_loss: 0.1843 - val_accuracy: 0.9523
Epoch 3/100
200/200 [==============================] - 2s 8ms/step - loss: 0.1779 - accuracy: 0.9548 - val_loss: 0.1020 - val_accuracy: 0.9762
Epoch 4/100
200/200 [==============================] - 2s 8ms/step - loss: 0.1384 - accuracy: 0.9625 - val_loss: 0.0805 - val_accuracy: 0.9863
Epoch 5/100
200/200 [==============================] - 2s 8ms/step - loss: 0.1149 - accuracy: 0.9707 - val_loss: 0.0779 - val_accuracy: 0.9805
Epoch 6/100
200/200 [==============================] - 2s 10ms/step - loss: 0.0696 - accuracy: 0.9816 - val_loss: 0.0535 - val_accuracy: 0.9848
Epoch 7/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0654 - accuracy: 0.9828 - val_loss: 0.0235 - val_accuracy: 0.9957
Epoch 8/100
200/200 [==============================] - 2s 10ms/step - loss: 0.0484 - accuracy: 0.9860 - val_loss: 0.0409 - val_accuracy: 0.9891
Epoch 9/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0446 - accuracy: 0.9866 - val_loss: 0.0199 - val_accuracy: 0.9961
Epoch 10/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0409 - accuracy: 0.9899 - val_loss: 0.0254 - val_accuracy: 0.9945
Epoch 11/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0267 - accuracy: 0.9918 - val_loss: 0.0133 - val_accuracy: 0.9965
Epoch 12/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0269 - accuracy: 0.9930 - val_loss: 0.0169 - val_accuracy: 0.9953
Epoch 13/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0226 - accuracy: 0.9931 - val_loss: 0.0039 - val_accuracy: 0.9996
Epoch 14/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0285 - accuracy: 0.9924 - val_loss: 0.0102 - val_accuracy: 0.9969
Epoch 15/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0218 - accuracy: 0.9935 - val_loss: 0.0049 - val_accuracy: 0.9992
Epoch 16/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0177 - accuracy: 0.9951 - val_loss: 0.0051 - val_accuracy: 0.9992
Epoch 17/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0121 - accuracy: 0.9968 - val_loss: 0.0031 - val_accuracy: 1.0000
Epoch 18/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0197 - accuracy: 0.9950 - val_loss: 0.0046 - val_accuracy: 1.0000
Epoch 19/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0113 - accuracy: 0.9967 - val_loss: 0.0021 - val_accuracy: 1.0000
Epoch 20/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0116 - accuracy: 0.9966 - val_loss: 0.0103 - val_accuracy: 0.9965
Epoch 21/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0115 - accuracy: 0.9960 - val_loss: 0.0018 - val_accuracy: 1.0000
Epoch 22/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0072 - accuracy: 0.9983 - val_loss: 0.0013 - val_accuracy: 1.0000
Epoch 23/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0087 - accuracy: 0.9975 - val_loss: 0.0019 - val_accuracy: 0.9996
Epoch 24/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0082 - accuracy: 0.9977 - val_loss: 0.0027 - val_accuracy: 1.0000
Epoch 25/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0397 - accuracy: 0.9920 - val_loss: 0.0040 - val_accuracy: 1.0000
Epoch 26/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0095 - accuracy: 0.9972 - val_loss: 0.0031 - val_accuracy: 0.9992
Epoch 27/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0087 - accuracy: 0.9976 - val_loss: 0.0018 - val_accuracy: 0.9996
Epoch 28/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0047 - accuracy: 0.9989 - val_loss: 2.2831e-04 - val_accuracy: 1.0000
Epoch 29/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0053 - accuracy: 0.9985 - val_loss: 8.8028e-04 - val_accuracy: 0.9996
Epoch 30/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0049 - accuracy: 0.9984 - val_loss: 0.0010 - val_accuracy: 0.9996
Epoch 31/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0033 - accuracy: 0.9995 - val_loss: 0.0011 - val_accuracy: 1.0000
Epoch 32/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0028 - accuracy: 0.9993 - val_loss: 0.0012 - val_accuracy: 1.0000
Epoch 33/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0051 - accuracy: 0.9985 - val_loss: 5.5807e-04 - val_accuracy: 1.0000
Epoch 34/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0178 - accuracy: 0.9947 - val_loss: 0.0031 - val_accuracy: 1.0000
Epoch 35/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0090 - accuracy: 0.9971 - val_loss: 0.2521 - val_accuracy: 0.9660
Epoch 36/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0176 - accuracy: 0.9963 - val_loss: 0.0016 - val_accuracy: 1.0000
Epoch 37/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0039 - accuracy: 0.9992 - val_loss: 0.0011 - val_accuracy: 1.0000
Epoch 38/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0040 - accuracy: 0.9988 - val_loss: 3.9257e-04 - val_accuracy: 1.0000
Epoch 39/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0033 - accuracy: 0.9991 - val_loss: 4.2694e-04 - val_accuracy: 1.0000
Epoch 40/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0051 - accuracy: 0.9987 - val_loss: 1.0669e-04 - val_accuracy: 1.0000
Epoch 41/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0032 - accuracy: 0.9991 - val_loss: 1.2351e-04 - val_accuracy: 1.0000
Epoch 42/100
200/200 [==============================] - 2s 9ms/step - loss: 0.0048 - accuracy: 0.9985 - val_loss: 0.0012 - val_accuracy: 0.9996
Epoch 43/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0035 - accuracy: 0.9988 - val_loss: 0.0015 - val_accuracy: 0.9996
Epoch 44/100
200/200 [==============================] - 2s 9ms/step - loss: 0.0023 - accuracy: 0.9991 - val_loss: 6.4988e-04 - val_accuracy: 1.0000
Epoch 45/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0141 - accuracy: 0.9959 - val_loss: 0.0495 - val_accuracy: 0.9844
Epoch 46/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0116 - accuracy: 0.9972 - val_loss: 2.9027e-04 - val_accuracy: 1.0000
Epoch 47/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0039 - accuracy: 0.9988 - val_loss: 3.0627e-04 - val_accuracy: 1.0000
Epoch 48/100
200/200 [==============================] - 2s 9ms/step - loss: 0.0026 - accuracy: 0.9992 - val_loss: 7.4035e-04 - val_accuracy: 0.9996
Epoch 49/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0019 - accuracy: 0.9996 - val_loss: 7.1121e-05 - val_accuracy: 1.0000
Epoch 50/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0022 - accuracy: 0.9995 - val_loss: 4.7067e-04 - val_accuracy: 0.9996
Epoch 51/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0017 - accuracy: 0.9996 - val_loss: 2.2580e-04 - val_accuracy: 1.0000
Epoch 52/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0029 - accuracy: 0.9992 - val_loss: 9.2311e-05 - val_accuracy: 1.0000
Epoch 53/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0018 - accuracy: 0.9995 - val_loss: 2.8221e-04 - val_accuracy: 1.0000
Epoch 54/100
200/200 [==============================] - 2s 9ms/step - loss: 0.0014 - accuracy: 0.9995 - val_loss: 5.4999e-04 - val_accuracy: 1.0000
Epoch 55/100
200/200 [==============================] - 3s 17ms/step - loss: 9.1892e-04 - accuracy: 0.9998 - val_loss: 2.2409e-04 - val_accuracy: 1.0000
Epoch 56/100
200/200 [==============================] - 3s 15ms/step - loss: 0.0016 - accuracy: 0.9995 - val_loss: 3.9330e-05 - val_accuracy: 1.0000
Epoch 57/100
200/200 [==============================] - 2s 11ms/step - loss: 0.0014 - accuracy: 0.9997 - val_loss: 0.0013 - val_accuracy: 0.9992
Epoch 58/100
200/200 [==============================] - 2s 10ms/step - loss: 0.0040 - accuracy: 0.9987 - val_loss: 0.0115 - val_accuracy: 0.9957
Epoch 59/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0075 - accuracy: 0.9982 - val_loss: 9.8462e-04 - val_accuracy: 0.9996
Epoch 60/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0022 - accuracy: 0.9994 - val_loss: 7.8797e-05 - val_accuracy: 1.0000
Epoch 61/100
200/200 [==============================] - 2s 9ms/step - loss: 0.0010 - accuracy: 0.9998 - val_loss: 8.4627e-05 - val_accuracy: 1.0000
Epoch 62/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0023 - accuracy: 0.9994 - val_loss: 6.6188e-05 - val_accuracy: 1.0000
Epoch 63/100
200/200 [==============================] - 2s 8ms/step - loss: 9.3395e-04 - accuracy: 0.9999 - val_loss: 3.2495e-04 - val_accuracy: 0.9996
Epoch 64/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0011 - accuracy: 0.9997 - val_loss: 5.7823e-04 - val_accuracy: 0.9996
Epoch 65/100
200/200 [==============================] - 2s 10ms/step - loss: 0.0049 - accuracy: 0.9986 - val_loss: 3.9062e-04 - val_accuracy: 1.0000
Epoch 66/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0018 - accuracy: 0.9995 - val_loss: 5.8173e-05 - val_accuracy: 1.0000
Epoch 67/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0038 - accuracy: 0.9986 - val_loss: 1.0518e-04 - val_accuracy: 1.0000
Epoch 68/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0020 - accuracy: 0.9995 - val_loss: 4.2416e-04 - val_accuracy: 1.0000
Epoch 69/100
200/200 [==============================] - 2s 9ms/step - loss: 0.0042 - accuracy: 0.9986 - val_loss: 2.1521e-04 - val_accuracy: 1.0000
Epoch 70/100
200/200 [==============================] - 3s 14ms/step - loss: 0.0025 - accuracy: 0.9991 - val_loss: 2.4780e-04 - val_accuracy: 1.0000
Epoch 71/100
200/200 [==============================] - 2s 10ms/step - loss: 6.8856e-04 - accuracy: 0.9998 - val_loss: 2.1114e-05 - val_accuracy: 1.0000
Epoch 72/100
200/200 [==============================] - 2s 10ms/step - loss: 0.0017 - accuracy: 0.9998 - val_loss: 0.0043 - val_accuracy: 0.9988
Epoch 73/100
200/200 [==============================] - 3s 14ms/step - loss: 6.7726e-04 - accuracy: 0.9998 - val_loss: 9.4120e-05 - val_accuracy: 1.0000
Epoch 74/100
200/200 [==============================] - 2s 10ms/step - loss: 0.0156 - accuracy: 0.9966 - val_loss: 2.6332e-04 - val_accuracy: 1.0000
Epoch 75/100
200/200 [==============================] - 2s 9ms/step - loss: 0.0025 - accuracy: 0.9992 - val_loss: 5.4001e-04 - val_accuracy: 0.9996
Epoch 76/100
200/200 [==============================] - 2s 8ms/step - loss: 9.3370e-04 - accuracy: 0.9998 - val_loss: 1.5566e-05 - val_accuracy: 1.0000
Epoch 77/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0026 - accuracy: 0.9989 - val_loss: 4.7303e-05 - val_accuracy: 1.0000
Epoch 78/100
200/200 [==============================] - 1s 7ms/step - loss: 8.7546e-04 - accuracy: 0.9997 - val_loss: 1.3028e-05 - val_accuracy: 1.0000
Epoch 79/100
200/200 [==============================] - 2s 8ms/step - loss: 3.8067e-04 - accuracy: 0.9999 - val_loss: 2.4801e-05 - val_accuracy: 1.0000
Epoch 80/100
200/200 [==============================] - 2s 8ms/step - loss: 5.7231e-04 - accuracy: 0.9999 - val_loss: 4.2533e-06 - val_accuracy: 1.0000
Epoch 81/100
200/200 [==============================] - 2s 8ms/step - loss: 7.1726e-04 - accuracy: 0.9998 - val_loss: 2.4702e-05 - val_accuracy: 1.0000
Epoch 82/100
200/200 [==============================] - 2s 9ms/step - loss: 9.5949e-04 - accuracy: 0.9997 - val_loss: 0.0013 - val_accuracy: 0.9996
Epoch 83/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0047 - accuracy: 0.9986 - val_loss: 1.7136e-04 - val_accuracy: 1.0000
Epoch 84/100
200/200 [==============================] - 2s 10ms/step - loss: 0.0011 - accuracy: 0.9995 - val_loss: 5.7767e-05 - val_accuracy: 1.0000
Epoch 85/100
200/200 [==============================] - 2s 9ms/step - loss: 8.7595e-04 - accuracy: 0.9998 - val_loss: 1.6236e-05 - val_accuracy: 1.0000
Epoch 86/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0011 - accuracy: 0.9997 - val_loss: 2.0376e-05 - val_accuracy: 1.0000
Epoch 87/100
200/200 [==============================] - 2s 8ms/step - loss: 9.7717e-04 - accuracy: 0.9998 - val_loss: 6.6056e-04 - val_accuracy: 0.9996
Epoch 88/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0028 - accuracy: 0.9991 - val_loss: 2.8229e-05 - val_accuracy: 1.0000
Epoch 89/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0012 - accuracy: 0.9996 - val_loss: 5.7100e-06 - val_accuracy: 1.0000
Epoch 90/100
200/200 [==============================] - 2s 9ms/step - loss: 0.0020 - accuracy: 0.9995 - val_loss: 4.9737e-06 - val_accuracy: 1.0000
Epoch 91/100
200/200 [==============================] - 2s 8ms/step - loss: 6.1796e-04 - accuracy: 0.9999 - val_loss: 1.2756e-05 - val_accuracy: 1.0000
Epoch 92/100
200/200 [==============================] - 2s 9ms/step - loss: 5.7106e-04 - accuracy: 0.9998 - val_loss: 4.2657e-06 - val_accuracy: 1.0000
Epoch 93/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0013 - accuracy: 0.9995 - val_loss: 8.3916e-05 - val_accuracy: 1.0000
Epoch 94/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0038 - accuracy: 0.9991 - val_loss: 4.3363e-05 - val_accuracy: 1.0000
Epoch 95/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0012 - accuracy: 0.9996 - val_loss: 4.5219e-05 - val_accuracy: 1.0000
Epoch 96/100
200/200 [==============================] - 2s 8ms/step - loss: 5.3631e-04 - accuracy: 0.9999 - val_loss: 1.2013e-06 - val_accuracy: 1.0000
Epoch 97/100
200/200 [==============================] - 2s 8ms/step - loss: 9.6246e-04 - accuracy: 0.9995 - val_loss: 2.4490e-04 - val_accuracy: 1.0000
Epoch 98/100
193/200 [===========================>..] - ETA: 0s - loss: 0.0013 - accuracy: 0.9997WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 40 batches). You may need to use the repeat() function when building your dataset.
200/200 [==============================] - 2s 8ms/step - loss: 0.0014 - accuracy: 0.9995 - val_loss: 1.2528e-05 - val_accuracy: 1.0000
Epoch 99/100
200/200 [==============================] - 1s 7ms/step - loss: 0.0029 - accuracy: 0.9993
Epoch 100/100
200/200 [==============================] - 2s 8ms/step - loss: 0.0011 - accuracy: 0.9998
  • Save the model
In [ ]:
model.save('../models/small_model.h5')

5 - Model Evaluation & Model Tuning¶

  • Model Evaluation - Loss
In [ ]:
fig = plt.figure(figsize=(15, 5))
plt.plot(hist.history['loss'], label='(training data)')
plt.plot(hist.history['val_loss'], label='(test data)')

plt.ylabel('value')
plt.xlabel('No. epoch')
plt.legend(loc="upper right")
plt.title('Loss')
plt.show()
  • Model Evaluation - Accuracy
In [ ]:
fig = plt.figure(figsize=(15, 5))
plt.plot(hist.history['accuracy'], label='(training data)')
plt.plot(hist.history['val_accuracy'], label='(test data)')

plt.ylabel('value')
plt.xlabel('No. epoch')
plt.legend(loc="lower right")
plt.title('Accuracy')
plt.show()
  • Prediction - Prediction
In [ ]:
res = (
  model.predict_generator(pred_gen(test), steps=test_size)
  .argmax(axis=1)
  .reshape(-1, 8, 8)
)
  • Model Evaluation - All metrics
In [ ]:
pred_fens = np.array([fen_from_onehot(one_hot) for one_hot in res])
test_fens = np.array([get_image_FEN_label(fn) for fn in test])

final_accuracy = (pred_fens == test_fens).astype(float).mean()

print("Final Accuracy: {:1.5f}%".format(final_accuracy))
Final Accuracy: 0.99900%
In [ ]:
print("\nConfusion Matrix:\n------------------------")
confusion_matrix(test_fens, pred_fens)
Confusion Matrix:
------------------------
Out[ ]:
array([[1, 0, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       [0, 0, 1, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 1, 0, 0],
       [0, 0, 0, ..., 0, 1, 0],
       [0, 0, 0, ..., 0, 0, 1]])
In [ ]:
print("Precison:", precision_score(test_fens, pred_fens, average='weighted'))
print("Recall:", recall_score(test_fens, pred_fens, average='weighted'))
print("F1 Score:", f1_score(test_fens, pred_fens, average='weighted'))
print("Accuracy:", accuracy_score(test_fens, pred_fens))
Precison: 0.999
Recall: 0.999
F1 Score: 0.999
Accuracy: 0.999

6 - Preview Predictions and Outliers¶

  • Correct predictions
In [ ]:
import matplotlib.image as mpimg
def display_with_predicted_fen(image):
    plt.figure(figsize=(5,5))
    pred = model.predict(preprocess_image(image)).argmax(axis=1).reshape(-1, 8, 8)
    fen = fen_from_onehot(pred[0])
    imgplot = plt.imshow(mpimg.imread(image))
    plt.axis('off')
    plt.title(fen)
    plt.show()
    return fen
  
In [ ]:
predicted_fen=display_with_predicted_fen(test[230])
print("predicted FEN :",predicted_fen)
board = chess.Board(predicted_fen.replace('-', '/'))
chess.svg.board(board, size=300)  
2/2 [==============================] - 0s 5ms/step
predicted FEN : 8-4K3-5N2-8-2B2k2-1p2p3-2p2r2-6R1
Out[ ]:
. . . . . . . .
. . . . K . . .
. . . . . N . .
. . . . . . . .
. . B . . k . .
. p . . p . . .
. . p . . r . .
. . . . . . R .
  • Outliers ! - Not too many :)
In [ ]:
mask = (pred_fens != test_fens)
predicted_outliers=pred_fens[mask]
outliers=test_fens[mask]
print("how many outliers are there?",len(outliers))
how many outliers are there? 4
In [ ]:
outliers[rand]
Out[ ]:
'1r6-2p4B-8-K7-8-8-8-4k3'
In [ ]:
if len(outliers)>0:
    rand = np.random.randint(0, len(outliers))
    predicted_fen=display_with_predicted_fen('../data/test/'+outliers[rand]+'.jpeg')
    print("Actual FEN: "+outliers[rand])
    print("predicted FEN :",predicted_fen)
    board = chess.Board(predicted_fen.replace('-', '/'))
    display(chess.svg.board(board, size=300))
2/2 [==============================] - 0s 3ms/step
Actual FEN: 1r6-2p4B-8-K7-8-8-8-4k3
predicted FEN : 1r6-7B-8-K7-8-8-8-4k3
. r . . . . . .
. . . . . . . B
. . . . . . . .
K . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . k . . .

7 - Test Randomly Generated samples of the data¶

  • Function Generator to generate random samples of the data
In [ ]:
import os
if not os.path.isdir('../random/'):
    os.mkdir('../random/')
In [ ]:
from cairosvg import svg2png

def generate_random_image_from_FEN(FEN,optional_path='../random/'):
    board = chess.Board(FEN)
    boardsvg = chess.svg.board(board, size=400,coordinates=False)
    FEN = FEN.replace('/', '-')
    return svg2png(bytestring=boardsvg,write_to=optional_path+FEN+'.jpeg')

generate_random_image_from_FEN('2Kp4/2kPbnR1/1p1P1p1P/4Q2q/8/8/P2p1pp1/8')
generate_random_image_from_FEN('8/8/8/8/8/8/8/8')
In [ ]:
img  = mpimg.imread('../random/2Kp4-2kPbnR1-1p1P1p1P-4Q2q-8-8-P2p1pp1-8.jpeg')
plt.imshow(img)
Out[ ]:
<matplotlib.image.AxesImage at 0x2b7ac0130>
  • Test Some Generated Samples of the data
In [ ]:
predicted_fen=display_with_predicted_fen('../random/'+'8-8-8-8-8-8-8-8'+'.jpeg')
print("Actual FEN: "+'8-8-8-8-8-8-8-8')
print("predicted FEN :",predicted_fen)
board = chess.Board(predicted_fen.replace('-', '/'))
chess.svg.board(board, size=300)  
2/2 [==============================] - 0s 4ms/step
Actual FEN: 8-8-8-8-8-8-8-8
predicted FEN : 8-8-8-8-8-8-8-8
Out[ ]:
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . . . .
In [ ]:
predicted_fen=display_with_predicted_fen('../random/'+'2Kp4-2kPbnR1-1p1P1p1P-4Q2q-8-8-P2p1pp1-8'+'.jpeg')
print("Actual FEN: "+'2Kp4-2kPbnR1-1p1P1p1P-4Q2q-8-8-P2p1pp1-8')
print("predicted FEN :",predicted_fen)
board = chess.Board(predicted_fen.replace('-', '/'))
chess.svg.board(board, size=300)  
2/2 [==============================] - 0s 9ms/step
Actual FEN: 2Kp4-2kPbnR1-1p1P1p1P-4Q2q-8-8-P2p1pp1-8
predicted FEN : 2Kp4-2kPbnR1-1p1P1p1P-4Q2q-8-8-P2p1pp1-8
Out[ ]:
. . K p . . . .
. . k P b n R .
. p . P . p . P
. . . . Q . . q
. . . . . . . .
. . . . . . . .
P . . p . p p .
. . . . . . . .

8 - Model In production¶

In [ ]:
from tensorflow import keras
#model_p = keras.models.load_model('../models/small_model.h5')
model_p = keras.models.load_model('../models/perfect_model.h5')

res = (
  model_p.predict_generator(pred_gen(test), steps=test_size)
  .argmax(axis=1)
  .reshape(-1, 8, 8)
)

pred_fens = np.array([fen_from_onehot(one_hot) for one_hot in res])
test_fens = np.array([get_image_FEN_label(fn) for fn in test])

final_accuracy = (pred_fens == test_fens).astype(float).mean()

print("Final Accuracy: {:1.5f}%".format(final_accuracy))
2022-08-04 15:55:19.806888: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
Final Accuracy: 1.00000%
In [ ]:
import matplotlib.image as mpimg
def display_with_predicted_fen(image):
    plt.figure(figsize=(5,5))
    pred = model_p.predict(preprocess_image(image)).argmax(axis=1).reshape(-1, 8, 8)
    fen = fen_from_onehot(pred[0])
    imgplot = plt.imshow(mpimg.imread(image))
    plt.axis('off')
    plt.title(fen)
    plt.show()
    return fen

predicted_fen=display_with_predicted_fen(test[230])
print("predicted FEN :",predicted_fen)
board = chess.Board(predicted_fen.replace('-', '/'))
chess.svg.board(board, size=300)  
2/2 [==============================] - 0s 6ms/step
predicted FEN : 5k2-8-3p4-1K6-N7-8-8-5r2
Out[ ]:
. . . . . k . .
. . . . . . . .
. . . p . . . .
. K . . . . . .
N . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . r . .